# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower client interceptor."""


import base64
import collections
from logging import WARNING
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import grpc
from cryptography.hazmat.primitives.asymmetric import ec

from flwr.common.logger import log
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_public_key,
    compute_hmac,
    generate_shared_key,
    public_key_to_bytes,
)
from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
    CreateNodeRequest,
    DeleteNodeRequest,
    PingRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
)
from flwr.proto.run_pb2 import GetRunRequest  # pylint: disable=E0611

# gRPC metadata key under which the client sends its URL-safe base64-encoded
# public key on every call; the same key is read back from the response's
# initial metadata to obtain the server's public key after CreateNode.
_PUBLIC_KEY_HEADER = "public-key"
# gRPC metadata key carrying the URL-safe base64-encoded HMAC of the
# serialized request, computed with the client/server shared secret.
_AUTH_TOKEN_HEADER = "auth-token"

# Request message types accepted by `AuthenticateClientInterceptor`.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
    PingRequest,
]


def _get_value_from_tuples(
    key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
) -> bytes:
    value = next((value for key, value in tuples if key == key_string), "")
    if isinstance(value, str):
        return value.encode()

    return value


class _ClientCallDetails(
    collections.namedtuple("_ClientCallDetails", "method timeout metadata credentials"),
    grpc.ClientCallDetails,  # type: ignore
):
    """Replacement call details for an intercepted RPC.

    An instance is handed to the continuation function as its first argument.
    `AuthenticateClientInterceptor` uses it to forward the original method,
    timeout and credentials together with its augmented metadata list.
    """


class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor):  # type: ignore
    """Client interceptor for client authentication.

    Every unary-unary call is tagged with the client's public key. The response
    to a ``CreateNodeRequest`` is inspected for the server's public key, from
    which a shared secret is derived. All other known request types are signed
    with an HMAC of the serialized request computed with that shared secret.
    """

    def __init__(
        self,
        private_key: ec.EllipticCurvePrivateKey,
        public_key: ec.EllipticCurvePublicKey,
    ):
        """Store the client's key pair and pre-encode its public key.

        Parameters
        ----------
        private_key : ec.EllipticCurvePrivateKey
            Client private key, used to derive the shared secret.
        public_key : ec.EllipticCurvePublicKey
            Client public key, sent in the metadata of every call.
        """
        self.private_key = private_key
        self.public_key = public_key
        # Both are only populated after a CreateNode response carrying the
        # server's public key has been received.
        self.shared_secret: Optional[bytes] = None
        self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
        # Encoded once because the same header value is sent on every call.
        self.encoded_public_key = base64.urlsafe_b64encode(
            public_key_to_bytes(self.public_key)
        )

    def intercept_unary_unary(
        self,
        continuation: Callable[[Any, Any], Any],
        client_call_details: grpc.ClientCallDetails,
        request: Request,
    ) -> grpc.Call:
        """Flower client interceptor.

        Intercept unary call from client and add necessary authentication header in the
        RPC metadata.

        Parameters
        ----------
        continuation : Callable[[Any, Any], Any]
            gRPC continuation that performs the actual call.
        client_call_details : grpc.ClientCallDetails
            Details of the outgoing call; its metadata is copied, not mutated.
        request : Request
            The outgoing request message.

        Returns
        -------
        grpc.Call
            Whatever `continuation` returned, unchanged.

        Raises
        ------
        RuntimeError
            If a request that must be signed is sent before a shared secret
            has been established via a ``CreateNodeRequest``.
        """
        metadata = []
        postprocess = False
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)

        # Always add the public key header
        metadata.append(
            (
                _PUBLIC_KEY_HEADER,
                self.encoded_public_key,
            )
        )

        if isinstance(request, CreateNodeRequest):
            # The server's public key arrives in the response; handle it below.
            postprocess = True
        elif isinstance(
            request,
            (
                DeleteNodeRequest,
                PullTaskInsRequest,
                PushTaskResRequest,
                GetRunRequest,
                PingRequest,
            ),
        ):
            if self.shared_secret is None:
                raise RuntimeError("Failure to compute hmac")

            # Deterministic serialization so the server can recompute the
            # identical HMAC from the same message.
            metadata.append(
                (
                    _AUTH_TOKEN_HEADER,
                    base64.urlsafe_b64encode(
                        compute_hmac(
                            self.shared_secret, request.SerializeToString(True)
                        )
                    ),
                )
            )

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )

        response = continuation(client_call_details, request)
        if postprocess:
            # For a failed call (e.g. SuperLink unreachable) gRPC may report
            # initial_metadata() as None; treat that as "no server key" so the
            # warning below is logged instead of a TypeError masking the
            # original RPC error.
            server_public_key_bytes = base64.urlsafe_b64decode(
                _get_value_from_tuples(
                    _PUBLIC_KEY_HEADER, response.initial_metadata() or ()
                )
            )

            if server_public_key_bytes != b"":
                self.server_public_key = bytes_to_public_key(server_public_key_bytes)
            else:
                log(WARNING, "Can't get server public key, SuperLink may be offline")

            if self.server_public_key is not None:
                self.shared_secret = generate_shared_key(
                    self.private_key, self.server_public_key
                )

        return response
